//+------------------------------------------------------------------+
//|                                    CandlestickPatterns AI-EA.mq5 |
//|                                          Copyright 2023, Omegafx |
//|                 https://www.mql5.com/en/users/omegajoctan/seller |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, Omegafx"
#property link      "https://www.mql5.com/en/users/omegajoctan/seller"
#property version   "1.00"

#include <Trade\Trade.mqh> //The trading module
#include <Trade\PositionInfo.mqh> //Position handling module
#include <ta-lib.mqh> //For candlestick patterns
#include <Catboost.mqh> //Has a class for deploying a catboost model

//--- Trading helpers: m_trade sends/closes orders, m_position reads open positions.
CTrade m_trade;
CPositionInfo m_position;

//--- CatBoost model wrapper; the ONNX file is loaded in OnInit().
CCatboostClassifier catboost;

//--- Inputs. NOTE: comments are deliberately kept on their own lines - in MQL5 a
//--- trailing comment on an `input` line replaces the name shown in the inputs dialog.
//--- magic_number : tag used to recognise this EA's positions.
//--- slippage     : maximum price deviation in points (opening and closing).
//--- symbol_      : chart symbol to switch to (live only); also selects the model file name.
//--- timeframe_   : chart timeframe to switch to (live only).
//--- lookahead    : how many bars of the chart timeframe a position is held before closing.
input int magic_number = 21042025;
input int slippage = 100;
input string symbol_ = "XAUUSD";
input ENUM_TIMEFRAMES timeframe_ = PERIOD_D1;
input int lookahead = 1;

//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//| Validates inputs, switches the chart to symbol_/timeframe_ when  |
//| running live, loads the CatBoost ONNX model for symbol_ and      |
//| configures the CTrade helper.                                    |
//| Returns INIT_SUCCEEDED, INIT_PARAMETERS_INCORRECT or INIT_FAILED.|
//+------------------------------------------------------------------+
int OnInit()
  {
//--- lookahead is the holding period in bars. With lookahead < 1 the holding time
//--- computed in OnTick is <= 0 seconds, so every position would be closed on the
//--- very tick it was opened.
   if (lookahead < 1)
     {
       printf("%s invalid lookahead = %d, it must be at least 1",__FUNCTION__,lookahead);
       return INIT_PARAMETERS_INCORRECT;
     }

//--- In the strategy tester the symbol/timeframe come from the tester settings,
//--- so the chart is only switched when running live.
   if (!MQLInfoInteger(MQL_TESTER))
     if (!ChartSetSymbolPeriod(0, symbol_, timeframe_))
       {
         printf("%s failed to set symbol %s and timeframe %s, Check these values. Err = %d",__FUNCTION__,symbol_,EnumToString(timeframe_),GetLastError());
         return INIT_FAILED;
       }

//--- The model file is selected by symbol_ and read from the common ONNX folder.

   string model_file = StringFormat("CatBoost.CDLPatterns.%s.onnx",symbol_);
   if (!catboost.Init(model_file, ONNX_COMMON_FOLDER)) //Initialize the catboost model
     {
       printf("%s failed to initialize the CatBoost model %s",__FUNCTION__,model_file);
       return INIT_FAILED;
     }

//--- Order settings shared by all trade requests of this EA.

   m_trade.SetExpertMagicNumber(magic_number);
   m_trade.SetDeviationInPoints(slippage);
   m_trade.SetMarginMode();
   m_trade.SetTypeFillingBySymbol(Symbol());

//---
   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//| No explicit cleanup is performed here.                           |
//| reason - REASON_* code supplied by the terminal (unused).        |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
//--- intentionally empty
  }
//+------------------------------------------------------------------+
//| Evaluates the last closed bar (index 1) of the chart and, when   |
//| allowed, opens one minimum-volume position:                      |
//|  - at least one "special" pattern must be present (every pattern |
//|    except the plain white/black candle),                         |
//|  - no buy or sell position of this EA may already exist,         |
//|  - model class 1 opens a buy, class 0 opens a sell.              |
//+------------------------------------------------------------------+
void OpenPositionOnSignal()
  {
   string symbol = Symbol();
   ENUM_TIMEFRAMES tf = Period();

//--- iTime()/iOpen()/... return 0 when the bar is not available (e.g. history
//--- still loading); skip instead of feeding zero prices to the model.
   if (iTime(symbol, tf, 1) == 0)
      return;

   double open  = iOpen(symbol, tf, 1),
          high  = iHigh(symbol, tf, 1),
          low   = iLow(symbol, tf, 1),
          close = iClose(symbol, tf, 1);

//--- Evaluate every pattern once; the same values feed both the model and the filter.
   double white_candle     = CTALib::CDLWHITECANDLE(open, close),
          black_candle     = CTALib::CDLBLACKCANDLE(open, close),
          doji             = CTALib::CDLDOJI(open, close),
          dragonfly_doji   = CTALib::CDLDRAGONFLYDOJI(open, high, low, close),
          gravestone_doji  = CTALib::CDLGRAVESTONEDOJI(open, high, low, close),
          hammer           = CTALib::CDLHAMMER(open, high, low, close),
          inverted_hammer  = CTALib::CDLINVERTEDHAMMER(open, high, low, close),
          spinning_top     = CTALib::CDLSPINNINGTOP(open, high, low, close),
          bullish_marubozu = CTALib::CDLBULLISHMARUBOZU(open, high, low, close),
          bearish_marubozu = CTALib::CDLBEARISHMARUBOZU(open, high, low, close);

//--- Sum of the special patterns; a trade is only considered when it is positive.
   double special_patterns = doji + dragonfly_doji + gravestone_doji + hammer +
                             inverted_hammer + spinning_top + bullish_marubozu + bearish_marubozu;
   if (special_patterns <= 0)
      return;

//--- Only one position of this EA (per symbol) at a time, in either direction.
   if (PosExists(POSITION_TYPE_BUY) || PosExists(POSITION_TYPE_SELL))
      return;

//--- Feature order must match the order the model was trained with.
   vector x = {
               white_candle,
               black_candle,
               doji,
               dragonfly_doji,
               gravestone_doji,
               hammer,
               inverted_hammer,
               spinning_top,
               bullish_marubozu,
               bearish_marubozu
              };

   long signal = catboost.predict(x).cls; //Predicted class

   MqlTick ticks;
   if (!SymbolInfoTick(symbol, ticks))
     {
      printf("Failed to obtain ticks information, Error = %d",GetLastError());
      return;
     }

   double volume_ = SymbolInfoDouble(symbol, SYMBOL_VOLUME_MIN);

   bool sent = true;
   if (signal == 1)
      sent = m_trade.Buy(volume_, symbol, ticks.ask, 0, 0);
   else if (signal == 0)
      sent = m_trade.Sell(volume_, symbol, ticks.bid, 0, 0);

   if (!sent)
      printf("%s failed to open a position, retcode = %d (%s)",__FUNCTION__,
             m_trade.ResultRetcode(),m_trade.ResultRetcodeDescription());
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//| Tries to open a position from the model signal, then closes any  |
//| position held for `lookahead` bars of the chart timeframe.       |
//+------------------------------------------------------------------+
void OnTick()
  {
//---
   OpenPositionOnSignal();

//--- The holding-period check runs on every tick, even when no signal could be evaluated.
   CloseTradeAfterTime((Timeframe2Minutes(Period())*lookahead)*60); //Close the trade after `lookahead` bars of the trained timeframe
  }
//+------------------------------------------------------------------+
//| Returns true when at least one position of the given type exists |
//| on this chart's symbol with this EA's magic number.              |
//+------------------------------------------------------------------+
bool PosExists(ENUM_POSITION_TYPE type)
 {
    int idx = PositionsTotal();
    while (idx-- > 0)
      {
        if (!m_position.SelectByIndex(idx))
           continue;

        bool ours = (m_position.Symbol() == Symbol()) && (m_position.Magic() == magic_number);
        if (ours && m_position.PositionType() == type)
           return true;
      }

    return false;
 }
//+------------------------------------------------------------------+
//| Closes ONE position of the given type on this chart's symbol     |
//| with this EA's magic number.                                     |
//| Returns true as soon as a close succeeds; false if no matching   |
//| position could be closed (failed closes keep scanning).          |
//+------------------------------------------------------------------+
bool ClosePos(ENUM_POSITION_TYPE type)
 {
    for (int pos = PositionsTotal() - 1; pos >= 0; --pos)
      {
        if (!m_position.SelectByIndex(pos))
           continue;
        if (m_position.Symbol() != Symbol() || m_position.Magic() != magic_number)
           continue;
        if (m_position.PositionType() != type)
           continue;

        if (m_trade.PositionClose(m_position.Ticket()))
           return true; //stop after the first successful close
      }

    return false;
 }
//+------------------------------------------------------------------+
//| Closes every position of this EA (this chart's symbol and the    |
//| EA's magic number) that has been open for at least               |
//| period_seconds, using `slippage` as the allowed deviation.       |
//| Failed closes are logged and retried on a later call.            |
//+------------------------------------------------------------------+
void CloseTradeAfterTime(int period_seconds)
{
   for (int i = PositionsTotal() - 1; i >= 0; i--)
     {
      if (!m_position.SelectByIndex(i))
         continue;

      //--- Filter by symbol as well as magic (like PosExists/ClosePos) so another
      //--- instance running on a different symbol with the same magic is untouched.
      if (m_position.Symbol() != Symbol() || m_position.Magic() != magic_number)
         continue;

      if (TimeCurrent() - m_position.Time() < period_seconds)
         continue;

      ulong ticket = m_position.Ticket();
      if (!m_trade.PositionClose(ticket, slippage))
         printf("%s failed to close position #%I64u, retcode = %d (%s)",__FUNCTION__,ticket,
                m_trade.ResultRetcode(),m_trade.ResultRetcodeDescription());
     }
}
//+------------------------------------------------------------------+
//| Converts a chart timeframe into its length in minutes.           |
//| PERIOD_MN1 is approximated as 30 days (43200 minutes).           |
//| Returns 0 and logs the value for a timeframe not in the table.   |
//+------------------------------------------------------------------+
int Timeframe2Minutes(ENUM_TIMEFRAMES tf)
{
    //--- parallel lookup tables: frames[k] lasts minutes[k] minutes
    static const ENUM_TIMEFRAMES frames[] =
      {
       PERIOD_M1,  PERIOD_M2,  PERIOD_M3,  PERIOD_M4,  PERIOD_M5,  PERIOD_M6,
       PERIOD_M10, PERIOD_M12, PERIOD_M15, PERIOD_M20, PERIOD_M30,
       PERIOD_H1,  PERIOD_H2,  PERIOD_H3,  PERIOD_H4,  PERIOD_H6,  PERIOD_H8, PERIOD_H12,
       PERIOD_D1,  PERIOD_W1,  PERIOD_MN1
      };
    static const int minutes[] =
      {
       1,   2,   3,   4,   5,   6,
       10,  12,  15,  20,  30,
       60,  120, 180, 240, 360, 480, 720,
       1440,      // 1 day  = 1440 minutes
       10080,     // 1 week = 7 * 1440 minutes
       43200      // approx. 1 month = 30 * 1440 minutes
      };

    int count = ArraySize(frames);
    for (int k = 0; k < count; k++)
       if (frames[k] == tf)
          return minutes[k];

    PrintFormat("Unknown timeframe: %d", tf);
    return 0;
}
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+

